#!/usr/bin/env python
# -*- coding: utf8 -*-

import yaml
import copy
import os
import re



def yaml_load(in_path, load_type='rb'):
    '''
    Open ``in_path`` and parse its content with ``yaml.safe_load``.

    :param in_path: path of the YAML file to read.
    :param load_type: mode passed to ``open``; binary (``'rb'``) by default,
        in which case PyYAML decodes the bytes itself.
    :return: whatever ``yaml.safe_load`` produces for the document.
    '''
    with open(in_path, load_type) as stream:
        return yaml.safe_load(stream)
    
def yaml_dump(in_path, in_data, dump_type='w', allow_unicode=False, default_flow_style=False):
    '''
    Serialize ``in_data`` into the file ``in_path`` with ``yaml.safe_dump``.

    :param in_path: destination file path.
    :param in_data: data to serialize (only types ``safe_dump`` accepts).
    :param dump_type: mode passed to ``open`` (e.g. ``'w'`` or ``'a'``).
    :param allow_unicode: write non-ASCII characters as-is instead of escaping them.
    :param default_flow_style: forwarded to ``yaml.safe_dump``.
    '''
    # Text modes are opened as UTF-8 explicitly. With allow_unicode=True the
    # output can contain non-ASCII characters. The locale default encoding
    # (e.g. GBK or cp1252 on Windows) could fail to encode them. It could also
    # produce a file that yaml_load, which hands raw bytes to PyYAML, cannot
    # decode. Binary modes must not receive an ``encoding`` argument.
    open_kwargs = {} if 'b' in dump_type else {'encoding': 'utf-8'}
    with open(in_path, dump_type, **open_kwargs) as f:
        yaml.safe_dump(in_data, f, default_flow_style=default_flow_style, allow_unicode=allow_unicode)

def yaml_power_load(in_path):
    '''
    Load a YAML file with the extra ``!file`` tag enabled.

    1. Adds the ``!file`` constructor tag, which lets a value pull in an
       external YAML file directly, for example:
           aa: !file test/bb.yml
       The tag's path is tried as given first, then relative to the directory
       of the document being parsed; the tag yields ``None`` if neither is an
       existing file.

    Note: the constructor is registered on ``yaml.SafeLoader`` itself, so after
    this function has run, every ``yaml.safe_load`` in the process (including
    ``yaml_load``) understands ``!file``.

    :param in_path: path of the YAML file to load.
    :return: the parsed document, or ``None`` if ``in_path`` is not a file.
    '''
    def _construct_file(loader, node):
        # presumably loader.name is the path of the stream being parsed
        here = os.path.dirname(loader.name)
        target = node.value
        if os.path.isfile(target):
            return yaml_load(target)
        target = os.path.join(here, target)
        return yaml_load(target) if os.path.isfile(target) else None

    yaml.add_constructor("!file", _construct_file, yaml.SafeLoader)
    return yaml_load(in_path) if os.path.isfile(in_path) else None

# ---------------------------------------------------------------------
def get_full_path(base_path, path):
    '''
    Resolve ``path`` against ``base_path`` and return it if it names an existing file.

    Backslashes in both arguments are converted to forward slashes first.
    If ``path`` already names an existing file (absolute, or relative to the
    current working directory) it is returned as-is. Otherwise it is joined
    onto ``base_path``. For paths starting with ``./`` or ``../``, the ``.``
    and ``..`` segments are then collapsed textually.

    :param base_path: directory used to resolve a relative ``path``.
    :param path: file path such as ``x.yml``, ``./x.yml`` or ``../../x.yml``.
    :return: the resolved path with forward slashes, or ``None`` if no file exists there.
    '''
    path = path.replace('\\', '/')
    base_path = base_path.replace('\\', '/')
    if not os.path.isfile(path):
        candidate = os.path.join(base_path, path)
        if path.startswith(('./', '../')):
            # normpath collapses './' and '../' segments. It also copes with
            # dot-prefixed file names ('../.base.yml'), a trailing slash on
            # base_path, and more '../' segments than base_path has components.
            candidate = os.path.normpath(candidate)
        # join/normpath may introduce backslashes on Windows; keep forward slashes.
        path = candidate.replace('\\', '/')

    if os.path.isfile(path):
        return path
    return None

#-----------------------------------------------------------------------------
def anchor_parse(in_data):
    '''
    Resolve ``@`` references between the top-level entries of ``in_data``, in place.

    A value can only reference data that already exists in ``in_data``, and
    only between first-level entries (the top-level ``key: value`` pairs):

    - ``'@key'`` is replaced by ``in_data['key']`` (of any type);
    - ``'@key/rest'`` becomes ``in_data['key'] + '/rest'`` (the referenced
      value must then be a string).

    Backslashes in a referencing value are normalized to ``/`` first. Chained
    references (a value that resolves to another ``'@...'`` string) are
    followed until none remain.

    :param in_data: dict whose values are resolved in place.
    :raises ValueError: if a reference names an unknown key, or the
        references form a cycle.
    '''
    # Every pass moves each unresolved value at least one reference hop
    # forward, so an acyclic dict needs at most len(in_data) passes. Running
    # out of passes therefore means the references are circular.
    for _ in range(len(in_data) + 1):
        find_ref = False
        for k in in_data:
            _value = in_data[k]
            if not (isinstance(_value, str) and _value.startswith('@')):
                continue
            _value = _value.replace('\\', '/')
            ref_key = _value[1:].split('/')[0]
            if ref_key not in in_data:
                raise ValueError('unknow @{}'.format(ref_key))
            if '/' in _value:
                # Substitute only the leading '@key' anchor; any later '@key'
                # text inside the rest of the path is kept literally.
                _new_value = in_data[ref_key] + _value[1 + len(ref_key):]
            else:
                _new_value = in_data[ref_key]
            in_data[k] = _new_value
            # Check the *new* value: it may legitimately be a non-string.
            if isinstance(_new_value, str) and _new_value.startswith('@'):
                find_ref = True
        if not find_ref:
            return
    unresolved = [k for k in in_data
                  if isinstance(in_data[k], str) and in_data[k].startswith('@')]
    raise ValueError('circular @ reference: {}'.format(unresolved))

def anchor_parse_lookup(in_data, lookup_dict):
    """
    Resolve the ``@`` and ``<@`` anchors in ``in_data`` using ``lookup_dict``.

    - list: every element is resolved recursively.
    - dict: the value under ``'<@'`` (if present) is resolved first and must
      be a dict; its entries form the base of the result. The remaining keys
      are then resolved and merged on top:
        - a dict is shallow-updated into a base dict;
        - a list appends the items the base list does not contain yet;
        - any other value, or a value whose type differs from the base,
          replaces it.
    - ``'@key'``: a deep copy of ``lookup_dict['key']``.
    - ``'@key/rest'``: ``lookup_dict['key'] + '/rest'`` (backslashes are
      normalized to ``/``). If that result again starts with ``@``, it is
      resolved further.
    - anything else is returned unchanged.

    Lists and dicts are rebuilt, so ``in_data`` and ``lookup_dict`` are not modified.

    :raises ValueError: on an undefined reference or a non-dict ``<@`` value.
    """
    if isinstance(in_data, list):
        return [anchor_parse_lookup(_item, lookup_dict) for _item in in_data]

    if isinstance(in_data, dict):
        result = {}
        if '<@' in in_data:
            _base = anchor_parse_lookup(in_data['<@'], lookup_dict)
            if not isinstance(_base, dict):
                raise ValueError("Error config: %s must be dict!" % in_data['<@'])
            result.update(_base)
        for _k, _v in in_data.items():
            if _k == '<@':
                continue
            _rv = anchor_parse_lookup(_v, lookup_dict)
            if _k not in result or type(result[_k]) is not type(_rv):
                result[_k] = _rv
            elif isinstance(_rv, dict):
                result[_k].update(_rv)
            elif isinstance(_rv, list):
                for _item in _rv:
                    if _item not in result[_k]:
                        result[_k].append(_item)
            else:
                result[_k] = _rv
        return result

    if isinstance(in_data, str) and in_data.startswith('@'):
        # Normalize Windows separators so '@key\\sub' behaves like '@key/sub'.
        ref_path = in_data.replace('\\', '/')
        ref_key = ref_path[1:].split('/')[0]
        if ref_key not in lookup_dict:
            raise ValueError("Undefined Reference %s!" % ref_key)
        if '/' not in ref_path:
            return copy.deepcopy(lookup_dict[ref_key])
        # Substitute only the leading anchor, then resolve any anchor the
        # substituted value itself starts with (chained references).
        resolved = lookup_dict[ref_key] + ref_path[1 + len(ref_key):]
        return anchor_parse_lookup(resolved, lookup_dict)

    return in_data

#-------------------------------------------
def yaml_reference_load(in_path, _loading=None):
    '''
    Load a YAML config file that may reference (include) other config files.

    The file must follow one of these fixed layouts:
        1. No file reference:
            config: ...
        2. Single file reference:
            include: ref_path
            config: ...
        3. Multiple file references:
            includes:
                - ref01_path
                - ref02_path
                ...
            config: ...

    Each referenced file is loaded recursively with this same function. Their
    resolved ``config`` sections are merged, in the order returned by
    ``__includes_files``, into a lookup table. The ``config`` section of
    ``in_path`` is then resolved against that table. Both the ``@`` and ``<@``
    anchors are supported; see ``anchor_parse_lookup`` for their rules. The
    file is read through ``yaml_power_load``, so ``!file`` tags work as well.

    :param in_path: path of the YAML file to load.
    :param _loading: internal; real paths of the files currently being loaded,
        used to detect include cycles. Callers should not pass it.
    :return: the resolved ``config`` section.
    :raises ValueError: if the file is missing, empty or not a mapping, or if
        the includes form a cycle.
    :raises KeyError: if the file has no ``config`` key.
    '''
    loading = set() if _loading is None else _loading
    real_path = os.path.realpath(in_path)
    if real_path in loading:
        raise ValueError("Circular include detected: %s" % in_path)

    _data = yaml_power_load(in_path)
    if not isinstance(_data, dict):
        raise ValueError("Error config: %s is missing or is not a mapping!" % in_path)
    result_data = _data['config']

    lookup_data = {}
    loading.add(real_path)
    try:
        for _file in __includes_files(in_path, _data):
            lookup_data.update(yaml_reference_load(_file, loading))
    finally:
        # Only files on the current include chain count as a cycle, so a file
        # included by two siblings (a "diamond") is still allowed.
        loading.discard(real_path)
    return anchor_parse_lookup(result_data, lookup_data)
 
def __includes_files(in_path, in_data):
    """
    Resolve the file paths declared under ``include`` / ``includes`` in ``in_data``.

    Each path is resolved relative to the directory of ``in_path`` with
    ``get_full_path``; entries that do not resolve to an existing file are
    skipped. Duplicates are dropped while keeping declaration order
    (``include`` first, then the ``includes`` list). The caller therefore
    merges the included files in a deterministic order.

    :param in_path: path of the file that declares the includes.
    :param in_data: parsed content of ``in_path`` (a mapping).
    :return: list of resolved include paths.
    """
    declared = []
    if 'include' in in_data:
        declared.append(in_data['include'])
    if 'includes' in in_data:
        many = in_data['includes']
        if isinstance(many, str):
            # Tolerate a single path written under 'includes'.
            declared.append(many)
        elif many:
            declared.extend(many)

    base_dir = os.path.dirname(in_path)
    resolved = []
    for include in declared:
        if include is None:
            # An empty 'include:' / list entry in the YAML source.
            continue
        path = get_full_path(base_dir, include)
        # Order-preserving de-duplication. A set would make the merge order,
        # and so which include overrides which, vary between runs.
        if path and path not in resolved:
            resolved.append(path)
    return resolved
        
#-------------------------------------------
def yaml_template_load(in_path):
    '''
    Load a template YAML file and return its ``keys``, ``paths`` and ``strings`` sections.

    ``paths`` and ``strings`` each get their top-level ``@`` references
    resolved with ``anchor_parse``; each section is resolved independently.
    A missing section defaults to ``{}``. Every returned section is a deep copy.

    :param in_path: path of the template YAML file.
    :return: dict with the keys ``'keys'``, ``'paths'`` and ``'strings'``.
    '''
    raw = yaml_power_load(in_path)
    template = {'keys': copy.deepcopy(raw.get('keys', {}))}
    for section in ('paths', 'strings'):
        section_data = raw.get(section, {})
        anchor_parse(section_data)
        template[section] = copy.deepcopy(section_data)
    return template

    
